import os
from typing import Dict, List, Tuple

import torchvision.datasets as datasets
from torchvision.datasets.folder import IMG_EXTENSIONS, default_loader


class ImageNetSubClass(datasets.ImageFolder):
    """ImageFolder limited to the first ``class_num`` class directories.

    The sub-directories of ``root`` are sorted by name and only the first
    ``class_num`` of them are kept as classes; samples are collected from
    those directories only.

    Args:
        root: Root directory containing one sub-directory per class.
        transform: Forwarded unchanged to ``ImageFolder``.
        target_transform: Forwarded unchanged to ``ImageFolder``.
        loader: Forwarded unchanged to ``ImageFolder``.
        is_valid_file: Forwarded unchanged to ``ImageFolder``. When it is
            ``None`` files are filtered by ``IMG_EXTENSIONS`` instead.
        class_num: Number of (name-sorted) class directories to keep. It is
            used as a slice bound, so ``None`` keeps every class.

    Raises:
        RuntimeError: If no sample is found for the selected classes.
    """

    def __init__(self, root, transform=None, target_transform=None, loader=default_loader, is_valid_file=None, class_num=1000):
        # Must be set before the parent constructor runs: the class-discovery
        # overrides below read it while the parent indexes the dataset.
        self.class_num = class_num
        super().__init__(root, transform=transform, target_transform=target_transform, loader=loader, is_valid_file=is_valid_file)

        classes, class_to_idx = self._find_classes(self.root)
        if self.classes != classes:
            # The parent did not go through one of our overrides, so it
            # indexed a different class list. Rebuild the index for the
            # limited subset.
            extensions = IMG_EXTENSIONS if is_valid_file is None else None
            samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file)
            if not samples:
                msg = "Found 0 files in subfolders of: {}\n".format(self.root)
                if extensions is not None:
                    msg += "Supported extensions are: {}".format(",".join(extensions))
                raise RuntimeError(msg)

            self.loader = loader
            self.extensions = extensions

            self.classes = classes
            self.class_to_idx = class_to_idx
            self.samples = samples
            self.targets = [s[1] for s in samples]
            # ImageFolder exposes ``imgs`` as an alias of ``samples``; keep it
            # pointing at the rebuilt list rather than the parent's one.
            self.imgs = self.samples

    def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
        """
        Finds the class folders in a dataset, limited to ``self.class_num``.

        Provided next to ``_find_classes`` so that the limit is applied
        whichever of the two hook names the installed torchvision's
        ``DatasetFolder`` constructor calls.

        Args:
            directory (string): Root directory path.

        Returns:
            tuple: Same value as ``_find_classes(directory)``.
        """
        return self._find_classes(directory)

    def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]:
        """
        Finds the class folders in a dataset.

        Args:
            dir (string): Root directory path.

        Returns:
            tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
            Only the first ``self.class_num`` directory names (in sorted order) are returned.

        Ensures:
            No class is a subdirectory of another.
        """
        # Use the context-manager form so the directory iterator is closed
        # deterministically.
        with os.scandir(dir) as entries:
            classes = sorted(entry.name for entry in entries if entry.is_dir())
        # Keep only the first ``class_num`` classes.
        classes = classes[:self.class_num]
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx
